In [1]:
%matplotlib inline
import nengo.spa as spa
import numpy as np
import matplotlib.pyplot as plt
First, here's the SPA power function:
In [2]:
def power(s, e):
    """Raise the semantic pointer `s` to the (possibly fractional) power `e`.

    The exponent is applied elementwise to the Fourier spectrum of ``s.v``.
    For integer `e` this corresponds to repeatedly circular-convolving `s`
    with itself. Only the real part of the inverse transform is kept.

    Returns a new ``spa.SemanticPointer``.
    """
    spectrum = np.fft.fft(s.v)
    powered = np.fft.ifft(spectrum ** e)
    return spa.SemanticPointer(data=np.real(powered))
Here are two helper functions: `spatial_dot` computes the dot product between a vector and the encoded location at every point of a grid, and `spatial_plot` plots the resulting map.
In [5]:
def spatial_dot(v, X, Y, Z, xs, ys, transform=1):
    """Similarity between `v` and the encoded location at every grid point.

    Each grid point (x, y) is mapped to three coordinates (hx, hy, hz).
    These are projections onto three axes 120 degrees apart, scaled by 2/3,
    so they always sum to zero. The location is then encoded as
    ``power(X, hx) * power(Y, hy) * power(Z, hz) * transform``.

    Parameters
    ----------
    v : spa.SemanticPointer or array_like
        Vector to compare against. A SemanticPointer is unwrapped to ``.v``.
    X, Y, Z : spa.SemanticPointer
        Base pointers for the three axes.
    xs, ys : sequence of float
        Grid coordinates along x and y.
    transform : optional
        Extra factor multiplied into each encoded location (default 1).

    Returns
    -------
    ndarray of shape (len(ys), len(xs))
        Entry ``[row, col]`` is the dot product at ``(xs[col], ys[row])``.
    """
    target = v.v if isinstance(v, spa.SemanticPointer) else v
    result = np.zeros((len(ys), len(xs)))
    for col, x in enumerate(xs):
        for row, y in enumerate(ys):
            hx = 2/3 * y
            hy = (np.sqrt(3)/3 * x - y/3)
            hz = -(np.sqrt(3)/3 * x + y/3)
            encoded = power(X, hx) * power(Y, hy) * power(Z, hz) * transform
            result[row, col] = np.dot(target, encoded.v)
    return result
def spatial_plot(vs, colorbar=True, vmin=-1, vmax=1, cmap='plasma', extent=None):
    """Display a spatial map such as the one returned by `spatial_dot`.

    Parameters
    ----------
    vs : ndarray of shape (len(ys), len(xs))
        Map values; row j is assumed to correspond to ys[j] (increasing y).
    colorbar : bool
        Whether to draw a colorbar next to the image.
    vmin, vmax : float or None
        Color limits passed to ``plt.imshow``.
    cmap : str or Colormap
        Colormap passed to ``plt.imshow``.
    extent : (xmin, xmax, ymin, ymax), optional
        Axis extent of the image. If omitted, it falls back to the ends of
        the notebook-level ``xs`` and ``ys`` grids. Pass it explicitly to
        avoid depending on those globals.
    """
    if extent is None:
        # Backward-compatible fallback to the notebook's global grid.
        extent = (xs[0], xs[-1], ys[0], ys[-1])
    # imshow draws row 0 at the top by default. Flip the rows so that the
    # largest y ends up at the top of the image.
    vs = vs[::-1, :]
    plt.imshow(vs, interpolation='none', extent=extent, vmax=vmax, vmin=vmin, cmap=cmap)
    if colorbar:
        plt.colorbar()
In [7]:
# Dimensionality of the semantic pointers.
D = 64

# Seed numpy's global RNG so the random base pointers, and therefore every
# map below, are reproducible on Restart & Run All.
# NOTE(review): assumes spa.SemanticPointer draws from np.random when no rng
# is passed; confirm for the installed nengo version.
SEED = 0
np.random.seed(SEED)


def make_unitary_pointer(d):
    """Return a random `d`-dimensional SemanticPointer made unitary in place."""
    pointer = spa.SemanticPointer(d)
    pointer.make_unitary()
    return pointer


# Base pointers for the three spatial axes (created in the same order as before).
X = make_unitary_pointer(D)
Y = make_unitary_pointer(D)
Z = make_unitary_pointer(D)

# Sampling grid for all spatial maps.
xs = np.linspace(-3, 3, 50)
ys = np.linspace(-3, 3, 50)
So, that lets us take a vector and turn it into a spatial map. Now let's try going the other way around: specify a desired map, and find the vector that produces it.
In [8]:
# Target map: 1.0 inside the rectangle 0 < x < 2, -1 < y <= 3, and 0.0 elsewhere.
# meshgrid's default 'xy' indexing yields arrays of shape (len(ys), len(xs)),
# with grid_x[j, i] == xs[i] and grid_y[j, i] == ys[j]. This is the same
# [row=y, col=x] layout that spatial_dot returns, so the two can be compared
# directly even when xs and ys have different lengths.
grid_x, grid_y = np.meshgrid(xs, ys)
inside = (0 < grid_x) & (grid_x < 2) & (-1 < grid_y) & (grid_y <= 3)
desired = inside.astype(float)
spatial_plot(desired)
This can be treated as a least-squares minimization problem. In particular, we're trying to build the above map using a basis space. The basis vectors in that space are the spatial maps of the D unit vectors in our vector space! So let's compute those, and use our standard nengo solver:
In [9]:
# Row k of A is the flattened spatial map of the k-th standard basis vector.
# These maps are the basis that the desired map is fitted with.
basis = np.eye(D)
A = np.array([spatial_dot(basis[k], X, Y, Z, xs, ys).flatten() for k in range(D)])
In [10]:
import nengo
# Unregularized least squares: find v such that A.T @ v approximates the
# flattened desired map.
lstsq_solver = nengo.solvers.LstsqL2(reg=0)
v, info = lstsq_solver(np.array(A).T, desired.flatten())
In [12]:
# Rebuild the map from the least-squares vector and measure how close it is.
vs = spatial_dot(v, X, Y, Z, xs, ys)
residual = vs - desired
rmse = np.sqrt(np.mean(residual ** 2))
print(rmse)
spatial_plot(vs)
Yay!
However, one possible problem with this approach is that the norm of this vector is unconstrained:
In [13]:
# Euclidean norm of the unconstrained least-squares vector.
np.sqrt(np.dot(v, v))
Out[13]:
A better solution would add a constraint on the norm. For that, we use cvxpy to solve the norm-constrained least-squares problem directly.
In [14]:
import cvxpy as cvx
class CVXSolver(nengo.solvers.Solver):
    """Least-squares nengo solver with an upper bound on the solution norm.

    Solves ``min ||A d - Y||^2  subject to  ||d||_F <= norm_limit`` with cvxpy.

    Parameters
    ----------
    norm_limit : float
        Upper bound on the Frobenius norm of the solution. For a
        single-column target this is the ordinary Euclidean norm.
    """

    def __init__(self, norm_limit):
        super(CVXSolver, self).__init__(weights=False)
        self.norm_limit = norm_limit

    def __call__(self, A, Y, rng=np.random, E=None):
        """Solve the norm-constrained least-squares problem.

        Parameters
        ----------
        A : ndarray of shape (M, N)
            Matrix whose columns are combined by the solution.
        Y : ndarray of shape (M, K) or (M,)
            Targets. A 1-D target is treated as a single column, and the
            returned decoder is then 1-D as well.
        rng, E
            Accepted to match the solver call signature; unused here.

        Returns
        -------
        decoder : ndarray of shape (N, K), or (N,) for a 1-D `Y`
        info : dict
            ``rmses``: per-column RMS error of ``A @ decoder`` against `Y`.
            ``status``: the cvxpy problem status.

        Raises
        ------
        RuntimeError
            If cvxpy does not produce a solution.
        """
        A = np.asarray(A)
        Y = np.asarray(Y)
        vector_in = Y.ndim == 1
        if vector_in:
            Y = Y[:, None]
        N = A.shape[1]
        D = Y.shape[1]
        d = cvx.Variable((N, D))
        # Use `@`: `*` for matrix multiplication is deprecated in cvxpy.
        error = cvx.sum_squares(A @ d - Y)
        # 'fro' equals the 2-norm for a single column. For several columns it
        # bounds the total decoder norm; the default p=2 would bound the
        # spectral norm instead.
        constraints = [cvx.norm(d, 'fro') <= self.norm_limit]
        cvx_prob = cvx.Problem(cvx.Minimize(error), constraints)
        cvx_prob.solve()
        if d.value is None:
            raise RuntimeError(
                "cvxpy did not return a solution (status: %s)" % cvx_prob.status)
        decoder = np.asarray(d.value)
        rmses = np.sqrt(np.mean((Y - np.dot(A, decoder)) ** 2, axis=0))
        if vector_in:
            decoder = decoder[:, 0]
        return decoder, dict(rmses=rmses, status=cvx_prob.status)
In [15]:
# Same fit as before, but with the norm of the solution capped at 10.
target_column = desired.flatten().reshape(-1, 1)
v2, info2 = CVXSolver(norm_limit=10)(np.array(A).T, target_column)
# The solver returns a (D, 1) column; flatten it to a length-D vector.
v2 = v2.reshape(D)
In [17]:
# Rebuild the map from the norm-constrained vector and report error and norm.
vs2 = spatial_dot(v2, X, Y, Z, xs, ys)
rmse2 = np.sqrt(np.mean(np.square(vs2 - desired)))
print('rmse:', rmse2)
spatial_plot(vs2)
print('norm:', np.linalg.norm(v2))
Looks like the accuracy depends on what limit we put on the norm. Let's see how that varies:
In [18]:
# Sweep the norm limit over 1..10 and show the reconstruction for each value.
plt.figure(figsize=(10, 4))
limits = np.arange(10) + 1
# Solver inputs do not change across the sweep, so build them once.
sweep_A = np.array(A).T
sweep_Y = desired.flatten().reshape(-1, 1)
for i, limit in enumerate(limits):
    vv, _ = CVXSolver(norm_limit=limit)(sweep_A, sweep_Y)
    s = spatial_dot(vv.flatten(), X, Y, Z, xs, ys)
    error = np.sqrt(np.mean((s - desired) ** 2))
    plt.subplot(2, 5, i + 1)
    spatial_plot(s, colorbar=False)
    plt.title('norm: %g\nrmse: %1.2f' % (limit, error))
    plt.xticks([])
    plt.yticks([])
Looks like it works well even with norm limits that are not very large.
In [19]:
import seaborn
# Decompose the Gram matrix of the basis maps and plot the spatial map of each
# singular vector, labelled with its singular value.
SA = np.array(A).T  # A matrix passed to solver
gamma = SA.T.dot(SA)
U, S, V = np.linalg.svd(gamma)
# Lay the D components out on a w x h grid. Use true division before ceil
# (floor division made the ceil a no-op) so that w * h >= D even when D is
# not a perfect square.
w = int(np.sqrt(D))
h = int(np.ceil(D / float(w)))
# The colormap is identical for every panel, so build it once.
diverging_cmap = seaborn.diverging_palette(150, 275, s=80, l=55, as_cmap=True)
plt.figure(figsize=(16, 16))
for i in range(len(U)):
    # the columns of U are the left-singular vectors
    vs = spatial_dot(U[:, i], X, Y, Z, xs, ys)
    plt.subplot(w, h, i + 1)
    spatial_plot(vs, colorbar=False, vmin=None, vmax=None, cmap=diverging_cmap)
    # %.3g instead of %d so that singular values below 1 are not shown as 0.
    plt.title(r"$\sigma_{%d}(A^T A) = %.3g$" % (i + 1, S[i]))
    plt.xticks([])
    plt.yticks([])
plt.show()
In [ ]: